import pandas as pd
import sqlite3
import numpy as np
import torch
import gurobipy as gb
import graphviz
import random
from gurobipy import GRB
from itertools import product

# Reference lookup tables, shifted down by one (presumably stored 1-based on disk).
table2_array = np.load('Database/table2_array.npy') - 1
table3_array = np.load('Database/table3_array.npy') - 1
# One-hot encodings: row c holds a single 1 in column table*_array[c].
table2_array_bin = np.zeros((25,5), dtype=int)
table3_array_bin = np.zeros((125,5), dtype=int)
# Grid-shaped (5x5 and 5x5x5) copies of the same tables.
table2_cur = np.array(table2_array).reshape((5,5))
table3_cur = np.array(table3_array).reshape((5,5,5))
table2_array_bin[np.arange(25), table2_array.ravel()] = 1
table3_array_bin[np.arange(125), table3_array.ravel()] = 1

def intersection(a,b):
    """Return the elements present in both *a* and *b* as a list (order unspecified)."""
    shared = set(a)
    shared &= set(b)
    return list(shared)
def recode3(loc):
    """Split a flat 3-way-table index ``25*a + 5*b + c`` into the tuple ``(a, b, c)``."""
    high, remainder = divmod(loc, 25)
    return (high, remainder // 5, loc % 5)

def LoadData2(DataName):
    """Load table *DataName* from ``Database/analysisdata.db`` and prepare it for analysis.

    Steps:
      * index the frame by ``usid``;
      * flip (``1 - x``) any of the last six columns whose correlation with
        ``my4`` is below -0.05, so that ``outcome = 1`` always means a good result;
      * sort by ``my4`` (safety score) in decreasing order;
      * round the first 19 columns to ``int``;
      * append the derived ``1f1g`` / ``1l1k`` columns looked up in the module-level
        2-way table (inputs and outputs on the 1-based scale used by the database).

    Returns:
        pandas.DataFrame: the prepared data.
    """
    from contextlib import closing

    # Close the connection even if the query fails (previously it leaked on error).
    # NOTE(review): the table name is concatenated into the SQL text because SQL
    # identifiers cannot be bound as parameters; only pass trusted table names.
    with closing(sqlite3.connect('Database/analysisdata.db')) as DB:
        dt = pd.read_sql_query("SELECT * from " + DataName, DB)
    dt = dt.set_index('usid')
    # Flip the 0-1 encoding of any of the last six columns that is negatively
    # correlated with 'my4' (e.g. 'econ_c', 'admin_c', 'health_c', 'soccap_c').
    # Therefore, outcome = 1 always means good result.
    for i in range(6):
        if np.corrcoef(dt.my4, dt.iloc[:,-6+i])[0,1] < -0.05:
            dt.iloc[:,-6+i] = 1 - dt.iloc[:,-6+i]
    # sort by safety score decreasing
    dt = dt.sort_values(by='my4', ascending=False)
    dt = pd.concat([dt.iloc[:, :19].round().astype(int), dt.iloc[:, 19:]], axis=1)
    # Database scores are 1-based while table2_array is 0-based, hence the -1 / +1 shifts.
    dt['1f1g'] = table2_array[(dt.loc[:, 'mod1f_num'].values - 1) * 5 + (dt.loc[:, 'mod1g_num'].values - 1)]+1
    dt['1l1k'] = table2_array[(dt.loc[:, 'mod1l_num'].values - 1) * 5 + (dt.loc[:, 'mod1k_num'].values - 1)] + 1
    return dt


def GetRiskGain(TotalResult):
    """Build per-observation 5x5 risk and gain tables for moving between levels.

    ``TotalResult`` is indexed as ``TotalResult[level, draw, obs]``: the tables
    average over axis 1 and keep one entry per index of axis 2. It is assumed to
    be a torch tensor (any device) or a NumPy array with 5 entries on the first
    axis; TODO confirm with the caller.

    For ``1 <= i < j <= 4`` the increment of moving from level ``i`` to ``j`` is
    ``TotalResult[i+1] + ... + TotalResult[j]``:

    * ``Risk_table[i, j, o]`` is the fraction of draws in which that increment is
      negative for observation ``o``, and ``Risk_table[j, i] = 1 - Risk_table[i, j]``;
    * ``Gain_table[i, j, o]`` is the mean increment over draws, and
      ``Gain_table[j, i] = -Gain_table[i, j]``.

    Level 0 copies level 1's rows/columns, except that the 0<->1 moves get risk
    0.5 and gain -0.1. Diagonal entries stay 0.

    Returns:
        (Risk_table, Gain_table): NumPy arrays of shape ``(5, 5, TotalResult.shape[2])``.
    """

    def _to_host(values):
        # Tensors (possibly on an accelerator) must be moved to CPU before being
        # written into a NumPy array; NumPy inputs pass through unchanged.
        return values.cpu() if hasattr(values, 'cpu') else values

    n_draws = TotalResult.shape[1]
    n_obs = TotalResult.shape[2]
    Risk_table = np.zeros(shape=(5, 5, n_obs))
    Gain_table = np.zeros(shape=(5, 5, n_obs))
    for low in range(1, 4):
        increment = None
        for high in range(low + 1, 5):
            # Accumulate left to right, i.e. (T[low+1] + T[low+2]) + T[low+3].
            increment = TotalResult[high] if increment is None else increment + TotalResult[high]
            Risk_table[low, high, :] = _to_host(((increment < 0) * 1).sum(axis=0)) / n_draws
            Gain_table[low, high, :] = _to_host(increment.mean(axis=0))
    # Downward moves mirror the upward ones.
    for high in range(2, 5):
        for low in range(1, high):
            Risk_table[high, low, :] = 1 - Risk_table[low, high, :]
            Gain_table[high, low, :] = - Gain_table[low, high, :]
    # Level 0 behaves like level 1 for moves to/from levels 2-4.
    Risk_table[0, 2:, :] = Risk_table[1, 2:, :]
    Risk_table[2:, 0, :] = Risk_table[2:, 1, :]
    Risk_table[0, 1, :] = 0.5
    Risk_table[1, 0, :] = 0.5
    Gain_table[0, 2:, :] = Gain_table[1, 2:, :]
    Gain_table[2:, 0, :] = Gain_table[2:, 1, :]
    Gain_table[0, 1, :] = -0.1
    Gain_table[1, 0, :] = -0.1

    return Risk_table, Gain_table

def GetRiskGain2(TotalResult, price=0):
    """Like ``GetRiskGain`` but with gains measured on the outcome (normal-CDF) scale.

    ``TotalResult`` is a torch tensor (any device) indexed as
    ``TotalResult[level, draw, obs]``; TODO confirm the shape with the caller.
    ``price`` is accepted for interface compatibility but is not used.

    Risk is computed exactly as in ``GetRiskGain``. For ``1 <= i < j <= 4``,
    ``Gain_table[i, j, o] = mean_draws Phi(S_j) - mean_draws Phi(S_i)``, where
    ``S_k = TotalResult[0] + ... + TotalResult[k]`` and ``Phi`` is the standard
    normal CDF. The reverse moves are negated, and level 0 copies level 1 for
    moves to/from levels 2-4.

    Returns:
        (Risk_table, Gain_table): NumPy arrays of shape ``(5, 5, TotalResult.shape[2])``.
    """

    def _to_host(values):
        # Tensors (possibly on an accelerator) must be moved to CPU before being
        # written into a NumPy array.
        return values.cpu() if hasattr(values, 'cpu') else values

    n_draws = TotalResult.shape[1]
    n_obs = TotalResult.shape[2]
    # Create a 5*5 table indicating risk vector for change i to j, for each observation in the dataset.
    Risk_table = np.zeros(shape=(5, 5, n_obs))
    for low in range(1, 4):
        increment = None
        for high in range(low + 1, 5):
            increment = TotalResult[high] if increment is None else increment + TotalResult[high]
            Risk_table[low, high, :] = _to_host(((increment < 0) * 1).sum(axis=0)) / n_draws
    for high in range(2, 5):
        for low in range(1, high):
            Risk_table[high, low, :] = 1 - Risk_table[low, high, :]
    Risk_table[0, 2:, :] = Risk_table[1, 2:, :]
    Risk_table[2:, 0, :] = Risk_table[2:, 1, :]
    Risk_table[0, 1, :] = 0.5
    Risk_table[1, 0, :] = 0.5

    mynorm = torch.distributions.normal.Normal(0,1)
    # Mean over draws of Phi(S_k) for each level k = 1..4, computed once per level.
    cdf_mean = {k: _to_host(mynorm.cdf(TotalResult[:k + 1].sum(axis=0)).mean(axis=0))
                for k in range(1, 5)}
    Gain_table = np.zeros(shape=(5, 5, n_obs))
    for low in range(1, 4):
        for high in range(low + 1, 5):
            Gain_table[low, high, :] = cdf_mean[high] - cdf_mean[low]
    for high in range(2, 5):
        for low in range(1, high):
            Gain_table[high, low, :] = - Gain_table[low, high, :]
    Gain_table[0, 2:, :] = Gain_table[1, 2:, :]
    Gain_table[2:, 0, :] = Gain_table[2:, 1, :]
    # NOTE(review): these two constants differ by a factor of 10 (GetRiskGain uses
    # the same value for both directions); kept as-is, confirm whether intended.
    Gain_table[0, 1, :] = -0.000001
    Gain_table[1, 0, :] = -0.0000001

    return Risk_table, Gain_table



def add3waytable(m, start=None):
    """Add a 3-way (5x5x5 -> 0..4) lookup table as decision variables to Gurobi model *m*.

    Cell ``(a, b, c)`` lives at flat index ``25*a + 5*b + c``. Each cell gets a
    5-wide binary one-hot row and an integer level equal to that row dotted with
    ``[0..4]``. Levels must be non-decreasing along every axis and stay within
    +-1 of the module-level ``table3_array``. The MIP start is
    ``table3_array_bin`` unless *start* is given.

    Returns:
        (table3_integer, table3_binary): the integer (125,) and binary (125, 5) MVars.
    """
    level_values = np.arange(5, dtype=int)
    n_cells = 125

    def flat(a, b, c):
        return 25 * a + 5 * b + c

    binary = m.addMVar((n_cells, 5), vtype='B')
    integer = m.addMVar((n_cells,), vtype='I')
    # One-hot row <-> integer level.
    m.addConstrs(binary[cell, :] @ level_values == integer[cell] for cell in range(n_cells))
    m.addConstrs(sum(binary[cell, :]) == 1 for cell in range(n_cells))
    # Monotone along the last, middle and first axis, in that order.
    m.addConstrs(integer[flat(a, b, c)] <= integer[flat(a, b, c + 1)]
                 for a, b, c in product(range(5), range(5), range(4)))
    m.addConstrs(integer[flat(a, b, c)] <= integer[flat(a, b + 1, c)]
                 for a, b, c in product(range(5), range(4), range(5)))
    m.addConstrs(integer[flat(a, b, c)] <= integer[flat(a + 1, b, c)]
                 for a, b, c in product(range(4), range(5), range(5)))
    # Trust region: at most one level away from the reference table.
    m.addConstrs(integer[cell] - table3_array[cell] <= 1 for cell in range(n_cells))
    m.addConstrs(integer[cell] - table3_array[cell] >= -1 for cell in range(n_cells))
    binary.start = table3_array_bin if start is None else start
    return integer, binary

def add2waytable(m, start=None):
    """Add a 2-way (5x5 -> 0..4) lookup table as decision variables to Gurobi model *m*.

    Cell ``(a, b)`` lives at flat index ``5*a + b``. Each cell gets a 5-wide
    binary one-hot row (``table2_binary``) and an integer level
    (``table2_integer``) equal to that row dotted with ``[0..4]``. Levels must be
    non-decreasing along both axes and stay within +-1 of the module-level
    ``table2_array``. The MIP start is ``table2_array_bin`` unless *start* is given.

    Returns:
        (table2_integer, table2_binary): the integer (25,) and binary (25, 5) MVars.
    """
    level_values = np.arange(5, dtype=int)
    n_cells = 25
    binary = m.addMVar((n_cells, 5), vtype='B', name='table2_binary')
    integer = m.addMVar((n_cells,), vtype='I', name='table2_integer')
    # One-hot row <-> integer level.
    m.addConstrs(binary[cell, :] @ level_values == integer[cell] for cell in range(n_cells))
    m.addConstrs(sum(binary[cell, :]) == 1 for cell in range(n_cells))
    # Monotone along the second, then the first axis.
    m.addConstrs(integer[5 * a + b] <= integer[5 * a + b + 1] for a, b in product(range(5), range(4)))
    m.addConstrs(integer[5 * a + b] <= integer[5 * (a + 1) + b] for a, b in product(range(4), range(5)))
    # Trust region: at most one level away from the reference table.
    m.addConstrs(integer[cell] - table2_array[cell] <= 1 for cell in range(n_cells))
    m.addConstrs(integer[cell] - table2_array[cell] >= -1 for cell in range(n_cells))
    binary.start = table2_array_bin if start is None else start
    return integer, binary

def PL_4(X_input, Risk_table, Gain_table,R, Epsilon, penalty=0.005):
    """Solve a Gurobi MIP that chooses a new 3-way (level-4) table under a risk budget.

    Each row of ``X_input`` (a DataFrame whose first three columns are 0-4
    scores) maps to the flat cell ``25*x0 + 5*x1 + x2``. The objective maximises
    the mean gain of moving every row from its current level
    (``table3_array_bin``) to the proposed one, minus
    ``penalty * ||table3_integer - table3_array||^2``. The constraint bounds the
    *total* (not mean) risk over rows by ``Epsilon * R``.

    Returns:
        (model, table3_integer): the solved model and the integer table MVar.
    """
    n_obs = X_input.shape[0]
    model = gb.Model()  # model for producing level4 score
    table3_integer, table3_binary = add3waytable(model)
    # Flat 3-way cell of every observation.
    cell = X_input.iloc[:, 0] * 25 + X_input.iloc[:, 1] * 5 + X_input.iloc[:, 2] * 1
    current_onehot = table3_array_bin[cell, :]  # N x 5 one-hot of the current level
    proposed_onehot = table3_binary[cell, :]  # N x 5 binary decision rows of the new level
    risk_total = sum(current_onehot[i, :] @ Risk_table[:, :, i] @ proposed_onehot[i, :] for i in range(n_obs))
    gain_mean = sum(current_onehot[i, :] @ Gain_table[:, :, i] @ proposed_onehot[i, :] for i in range(n_obs)) / n_obs
    # ||x - a||^2 expanded as x.x - 2 x.a + a.a
    squared_shift = (table3_integer @ table3_integer - table3_integer @ table3_array * 2
                     + table3_array @ table3_array)
    Penalty = squared_shift * (-penalty)
    model.addConstr(risk_total <= Epsilon * R)
    model.setObjective(gain_mean + Penalty, GRB.MAXIMIZE)
    table3_binary.start = table3_array_bin
    model.optimize()
    return model, table3_integer


class CD_234_shared2:
    """Local search over the shared 2-way (``tb2``) and 3-way (``tb3``) lookup tables.

    ``Xlist`` is indexed positionally: 0 -> level-4 input, 1..3 -> level-3 inputs
    (3a, 3b, 3c), 4..9 -> level-2 inputs (2a .. 2f). Candidate tables are scored
    by :meth:`ComputeRiskGain`. The search methods (``onestep_rs``,
    ``onestep_sb``, ``recombination``) only accept candidates whose average risk
    does not exceed ``Epsilon``. ``R`` and ``penalty_scale`` are stored but not
    read by any method of this class.
    """

    def __init__(self, Xlist, Risk_table, Gain_table, R, Epsilon, penalty_scale):
        self.index_2 = None
        self.X4_input = Xlist[0]
        self.X3a_input = Xlist[1]
        self.X3b_input = Xlist[2]
        self.X3c_input = Xlist[3]
        self.X2a_input = Xlist[4]
        self.X2b_input = Xlist[5]
        self.X2c_input = Xlist[6]
        self.X2d_input = Xlist[7]
        self.X2e_input = Xlist[8]
        self.X2f_input = Xlist[9]
        self.connectivity3 = []
        self.get_connectivity()  # per cell: [larger neighbours, smaller neighbours]
        self.Risk_table = Risk_table
        self.Gain_table = Gain_table
        self.R = R
        self.N = self.X4_input.shape[0]
        self.Epsilon = Epsilon
        self.penalty_scale = penalty_scale
        # Working tables start as copies of the module-level reference tables.
        self.tb3 = np.zeros((125),dtype=int)
        self.tb2 = np.zeros((25), dtype=int)
        self.tb3[:] = table3_array[:]
        self.tb2[:] = table2_array[:]
        # Level-4 score of every observation under the reference 3-way table.
        self.cur_4 = table3_array[
                            self.X4_input.iloc[:, 0] * 25 + self.X4_input.iloc[:, 1] * 5 + self.X4_input.iloc[:, 2] * 1]

    def predict(self,X_input):
        """Propagate 19 columns of 0-4 scores through the current tables to the level-4 score.

        Two derived columns (positions 19 and 20) are appended using the
        module-level 2-way table before the 2 -> 3 -> 4 hierarchy is evaluated
        with ``self.tb2`` / ``self.tb3``.
        """
        dt = pd.DataFrame(X_input, dtype=int)
        dt['1f1g'] = table2_array[dt.iloc[:, 5].values * 5 + dt.iloc[:, 6].values ]
        dt['1l1k'] = table2_array[dt.iloc[:, 11].values * 5 + dt.iloc[:, 10].values ]
        New_2a = self.tb2[dt.iloc[:, 0].values * 5 + dt.iloc[:,1].values * 1]
        New_2b = self.tb3[dt.iloc[:, 4].values * 25 + dt.iloc[:,3].values*5+ dt.iloc[:,19].values * 1]
        New_2c = self.tb2[dt.iloc[:, 7].values * 5 + dt.iloc[:,8].values * 1]
        New_2d = self.tb3[dt.iloc[:, 9].values * 25 + dt.iloc[:,20].values*5+ dt.iloc[:,6].values * 1]
        New_2e = self.tb3[dt.iloc[:, 14].values * 25 + dt.iloc[:,13].values*5+ dt.iloc[:,15].values * 1]
        New_2f = self.tb3[dt.iloc[:, 17].values * 25 + dt.iloc[:,16].values*5+ dt.iloc[:,18].values * 1]
        New_3a = self.tb3[New_2a*25+New_2b*5+dt.iloc[:,2].values] #2a,2b,1c
        New_3b = self.tb3[New_2d*25+New_2c*5+dt.iloc[:,12].values] #2d,2c,1m
        New_3c = self.tb2[New_2e*5+New_2f] #2e,2f
        New_4 = self.tb3[New_3a * 25 + New_3b * 5 + New_3c]
        return New_4

    def ComputeRiskGain(self):
        """Return the mean (risk, gain) over observations of moving from ``cur_4`` to the new level-4 score."""
        New_2a = self.tb2[self.X2a_input['V1']*5+self.X2a_input['V2']]
        New_2b = self.tb3[self.X2b_input['V1']*25+self.X2b_input['V2']*5+self.X2b_input['V3']]
        New_2c = self.tb2[self.X2c_input['V1']*5+self.X2c_input['V2']]
        New_2d = self.tb3[self.X2d_input['V1']*25+self.X2d_input['V2']*5+self.X2d_input['V3']]
        New_2e = self.tb3[self.X2e_input['V1']*25+self.X2e_input['V2']*5+self.X2e_input['V3']]
        New_2f = self.tb3[self.X2f_input['V1']*25+self.X2f_input['V2']*5+self.X2f_input['V3']]
        New_3a = self.tb3[New_2a*25+New_2b*5+self.X3a_input['V3']] #2a,2b,1c
        New_3b = self.tb3[New_2d*25+New_2c*5+self.X3b_input['V3']] #2d,2c,1m
        New_3c = self.tb2[New_2e*5+New_2f] #2e,2f
        New_4 = self.tb3[New_3a*25+New_3b*5+New_3c]
        # (current level, new level, observation) index triples into the tables.
        tempindex = np.stack(
            [self.cur_4, New_4, np.array(range(self.N))], axis=1)
        risk = self.Risk_table[tuple(zip(*tempindex))].sum()
        gain = self.Gain_table[tuple(zip(*tempindex))].sum()
        return risk/self.N,gain/self.N

    def PDP(self, X_input):
        """Partial-dependence table: ``result[i, j]`` is the mean prediction with column ``i`` set to ``j``.

        ``X_input`` is a 2-D array with (at least) 19 columns of 0-4 scores; it is not modified.
        """
        temp = np.empty_like(X_input)
        temp[:] = X_input[:]
        result = np.empty(shape=(19, 5))
        for i in range(19):
            for j in range(5):
                temp[:,i] = j
                result[i,j] = self.predict(temp).mean()
                temp[:,i] = X_input[:,i]
        return result

    def CheckGraph2(self, loc):
        """Return ``(lower, upper)`` bounds for cell *loc* of ``tb2`` implied by its grid neighbours.

        ``lower`` is the max over smaller neighbours (-100 if none), ``upper`` the
        min over larger neighbours (100 if none).
        """
        UpperBound = 100
        LowerBound = -100
        x1 = loc // 5
        x2 = loc % 5
        Larger_posi = []
        Smaller_posi = []
        if x1 < 4:
            Larger_posi.append(loc + 5)
        if x2 < 4:
            Larger_posi.append(loc + 1)
        if x1 > 0:
            Smaller_posi.append(loc - 5)
        if x2 > 0:
            Smaller_posi.append(loc - 1)
        for i in Larger_posi:
            if i>=0 and i<25:
                UpperBound = min(UpperBound, self.tb2[i])
        for i in Smaller_posi:
            if i>=0 and i<25:
                LowerBound = max(LowerBound, self.tb2[i])
        return LowerBound, UpperBound

    def CheckGraph3(self, loc):
        """Return ``(lower, upper)`` bounds for cell *loc* of ``tb3`` implied by its grid neighbours."""
        UpperBound = 100
        LowerBound = -100
        x1 = loc //25
        x2 = (loc%25)//5
        x3 = loc%5
        Larger_posi = []
        Smaller_posi = []
        if x1 < 4:
            Larger_posi.append(loc+25)
        if x2<4:
            Larger_posi.append(loc+5)
        if x3<4:
            Larger_posi.append(loc+1)
        if x1>0:
            Smaller_posi.append(loc-25)
        if x2>0:
            Smaller_posi.append(loc-5)
        if x3>0:
            Smaller_posi.append(loc-1)

        for i in Larger_posi:
            if i>=0 and i<125:
                UpperBound = min(UpperBound, self.tb3[i])
        for i in Smaller_posi:
            if i>=0 and i<125:
                LowerBound = max(LowerBound, self.tb3[i])
        return LowerBound, UpperBound

    def GetIndex(self, size2, size3):
        """Pick random cells of ``tb2``/``tb3`` and all neighbour-consistent value combinations for them.

        Returns ``(index_2, index_3, combos_2, combos_3)``. Each ``combos_*`` is a
        shuffled 2-D array whose rows are candidate values for the chosen cells,
        clipped to 0..4.
        """
        index_2 = np.random.choice(25, size2, replace=False)
        index_3 = np.random.choice(125, size3, replace=False)
        index_2_range = []
        index_3_range = []
        for i in index_2:
            lower, upper = self.CheckGraph2(i)
            index_2_range.append(range(max(0, lower), min(4, upper) + 1))  # append [lower, upper]
        for j in index_3:
            lower, upper = self.CheckGraph3(j)
            index_3_range.append(range(max(0, lower), min(4, upper) + 1))  # append [lower, upper]
        index_2_product = product(*index_2_range)
        index_3_product = product(*index_3_range)
        return index_2, index_3, np.random.permutation(list(index_2_product)), np.random.permutation(list(index_3_product))

    def onestep_rs(self, size2 = 1, size3 = 5, regularization=0):
        """Exhaustively try all combinations for a random subset of cells and keep the best feasible one."""
        index_2, index_3, index_2_product, index_3_product = self.GetIndex(size2, size3)
        Risk, Gain = self.ComputeRiskGain()
        result = [self.tb2[index_2],self.tb3[index_3], Risk,Gain]
        for x in index_2_product:
            for y in index_3_product:
                self.tb2[index_2] = x
                self.tb3[index_3] = y
                Risk, Gain=self.ComputeRiskGain()
                if Risk > self.Epsilon:
                    continue
                elif Gain - regularization*Risk > result[-1] - regularization * result[-2]:
                    result = [x,y,Risk,Gain]
        self.tb2[index_2] = result[0]
        self.tb3[index_3] = result[1]

    def onestep_sb(self, size2=1, size3=1, Burst_size=10, regularization=0):
        """Random-burst search: ``Burst_size`` random single proposals, keeping the best feasible table."""
        Cur_risk_Gain = self.ComputeRiskGain()
        burst = [np.zeros((25),dtype=int), np.zeros((125), dtype=int), Cur_risk_Gain[0], Cur_risk_Gain[1]]
        burst[0][:] = self.tb2[:]
        burst[1][:] = self.tb3[:]

        for ite in range(Burst_size):
            # Redraw until at least one chosen cell has more than one admissible value.
            while True:
                index_2, index_3, index_2_product, index_3_product = self.GetIndex(size2, size3)
                if index_2_product.__len__()>1 or index_3_product.__len__()>1:
                    break
            flag = False
            # The nested loops only bound the number of retries: x and y are
            # re-drawn at random, and the first feasible draw ends this burst step.
            for x in index_2_product:
                for y in index_3_product:
                    x = random.choice(index_2_product)
                    y = random.choice(index_3_product)
                    self.tb2[index_2] = x
                    self.tb3[index_3] = y
                    Risk, Gain=self.ComputeRiskGain()
                    if Risk > self.Epsilon:
                        continue
                    elif Gain - regularization*Risk > burst[-1]-regularization*burst[-2]:
                        burst[0][:] = self.tb2[:]
                        burst[1][:] = self.tb3[:]
                        burst[2] = Risk
                        burst[3] = Gain
                    flag = True
                    break
                if flag:
                    break

        self.tb2[:] = burst[0][:]
        self.tb3[:] = burst[1][:]

    def get_connectivity(self):
        """Fill ``self.connectivity3[loc]`` with ``[larger neighbours, smaller neighbours]`` of each 3-way cell."""
        if self.connectivity3.__len__()>0:
            print('already initialized')
            return
        for loc in range(125):
            Larger_posi = []
            Smaller_posi = []
            x1 = loc //25
            x2 = (loc%25)//5
            x3 = loc%5
            if x1 < 4:
                Larger_posi.append(loc+25)
            if x2<4:
                Larger_posi.append(loc+5)
            if x3<4:
                Larger_posi.append(loc+1)
            if x1>0:
                Smaller_posi.append(loc-25)
            if x2>0:
                Smaller_posi.append(loc-5)
            if x3>0:
                Smaller_posi.append(loc-1)
            self.connectivity3.append([Larger_posi, Smaller_posi])

    def get_topological_sort_tb3(self):
        """Randomly re-split the cells of two adjacent levels of ``tb3`` along a random linear extension.

        A border ``b`` in 0..3 is drawn. All cells currently at level ``b`` or
        ``b + 1`` are ordered by a random topological sort (Kahn's algorithm
        restricted to those cells). A random non-empty prefix of that order is
        set to ``b`` and the rest to ``b + 1``. Prefix lengths whose split is
        close to the current one get larger ``exp(-0.15 * distance)`` weights.
        ``tb3`` is modified in place, and intermediate values are stored on
        ``self`` for inspection. If no cell is at level ``b`` or ``b + 1``,
        ``tb3`` is left unchanged.
        """
        # randomly select one border
        border = np.random.randint(4,size=1) # border and border+1 are the border we want to consider
        considered_loc = np.where((self.tb3==border)|(self.tb3==border+1))[0]
        if considered_loc.size == 0:
            # Nothing to re-split; the weight vector below would be 0/0.
            self.border = border
            return
        myque = []
        considered_num = considered_loc.__len__()
        considered_loc_set = set(considered_loc)
        # Independent [successors, predecessors] lists per cell; a ``[[[],[]]]*125``
        # literal would alias one inner list across all cells.
        temp_order = [[[], []] for _ in range(125)]
        for i in considered_loc:
            temp_order[i] = [intersection(self.connectivity3[i][0], considered_loc_set),
                intersection(self.connectivity3[i][1], considered_loc_set)]

        # Start with cells that have no predecessor inside the considered set.
        for i in considered_loc:
            if temp_order[i][1].__len__()==0:
                myque.append(i)
        smaller_area = []
        for ite in range(considered_num):
            random.shuffle(myque)
            cur = myque.pop()
            # A successor becomes ready once its last pending predecessor is emitted.
            for j in temp_order[cur][0]:
                temp_order[j][1].remove(cur)
                if temp_order[j][1].__len__()==0:
                    myque.append(j)
            smaller_area.append(cur)
        # dist[i]: how many of the first i ordered cells are currently at level ``border``.
        cur_smaller = set(np.where(self.tb3==border)[0])
        cur_smaller_num = cur_smaller.__len__()
        dist = [0]
        for i in range(considered_num):
            if smaller_area[i] in cur_smaller:
                dist.append(dist[-1]+1)
            else:
                dist.append(dist[-1])
        # i + |cur_smaller| - 2*dist[i] is the size of the symmetric difference between
        # the first i ordered cells and the current level-``border`` set.
        prob = np.exp(-0.15*np.array([i+cur_smaller_num-2*dist[i] for i in range(dist.__len__())]))
        prob[0] = 0  # never move every considered cell up to border+1
        prob = prob/prob.sum()
        stop_ite = np.random.choice(considered_loc.__len__()+1, p=prob)
        self.tb3[considered_loc] = border+1
        self.tb3[smaller_area[:int(stop_ite)]] = border
        self.dist = dist
        self.prob = prob
        self.temp_order = temp_order
        self.border = border
        self.smaller_area = smaller_area
        self.considered_loc = considered_loc
        self.myque = myque
        self.stop_ite = stop_ite
        return

    def recombination(self, burst_size):
        """Draw ``burst_size`` proposals from the same starting ``tb3`` and keep the best feasible one.

        The score is ``Gain - 0.005 * Risk``; proposals with risk above
        ``Epsilon`` are discarded.
        """
        # Snapshot the starting table. ``self.tb3[:]`` is only a view, which made
        # every proposal start from the previous (already perturbed) table.
        cur_tb3 = self.tb3.copy()
        Cur_risk_Gain = self.ComputeRiskGain()
        burst = [np.zeros((25), dtype=int), np.zeros((125), dtype=int), Cur_risk_Gain[0], Cur_risk_Gain[1]]
        burst[0][:] = self.tb2[:]
        burst[1][:] = self.tb3[:]
        for ite in range(burst_size):
            self.tb3[:] = cur_tb3
            self.get_topological_sort_tb3()
            Risk, Gain = self.ComputeRiskGain()
            print(Risk,Gain)
            if Risk > self.Epsilon:
                continue
            elif Gain - 0.005 * Risk > burst[-1] - 0.005 * burst[-2]:
                print('here')
                burst[0][:] = self.tb2[:]
                burst[1][:] = self.tb3[:]
                burst[2] = Risk
                burst[3] = Gain
        self.tb2[:] = burst[0][:]
        self.tb3[:] = burst[1][:]
        return



def _random_linear_extension(table):
    """Return a random level-by-level topological order of the cells of a monotone grid table.

    ``table`` is an n-dimensional array of non-negative integer levels. Cell
    ``x`` precedes ``y`` in the grid when ``x <= y`` coordinate-wise. Cells are
    emitted level by level (0, 1, ...). Within a level, the next cell is drawn
    uniformly (``np.random.choice``) among cells whose grid predecessors have
    all been emitted (Kahn's algorithm). Every prefix of the result is
    therefore closed downward, so cutting it into contiguous chunks yields a
    monotone table. The result holds flat C-order indices (``5*i + j`` for a
    5x5 table, ``25*i + 5*j + k`` for 5x5x5).

    Raises:
        ValueError: if ``table`` is empty, has negative levels, or is not
            non-decreasing along every axis (no level-respecting order exists).
    """
    table = np.asarray(table)
    if table.size == 0 or table.min() < 0:
        raise ValueError('table must be non-empty with non-negative levels')
    shape = table.shape
    ndim = table.ndim
    n_levels = int(table.max()) + 1
    # Number of grid predecessors of each cell that are not yet emitted.
    pending = {cell: sum(1 for c in cell if c > 0) for cell in np.ndindex(*shape)}
    queues = [[] for _ in range(n_levels)]
    root = (0,) * ndim
    queues[table[root]].append(root)
    order = []
    for level in range(n_levels):
        queue = queues[level]
        while queue:
            cell = queue.pop(np.random.choice(len(queue)))
            order.append(cell)
            for axis in range(ndim):
                if cell[axis] + 1 < shape[axis]:
                    succ = cell[:axis] + (cell[axis] + 1,) + cell[axis + 1:]
                    pending[succ] -= 1
                    if pending[succ] == 0:
                        # A lower-level successor lands in an already-drained
                        # queue and is never emitted; detected below.
                        queues[table[succ]].append(succ)
    if len(order) != table.size:
        raise ValueError('table must be non-decreasing along every axis')
    return np.array([np.ravel_multi_index(cell, shape) for cell in order])


def RandomInitialTP(table2_cur,table3_cur):
    """Draw random topological orders for the 5x5 and 5x5x5 tables.

    Each order lists every flat cell index exactly once, sorted by level and
    consistent with the grid partial order. Contiguous level slices of the
    order therefore reproduce a monotone table.

    Returns:
        (InitialState_2, InitialState_3): integer arrays of length 25 and 125.
    """
    InitialState_2 = _random_linear_extension(table2_cur)
    InitialState_3 = _random_linear_extension(table3_cur)
    return InitialState_2,InitialState_3


class TPSort:
    def __init__(self, Xlist, Risk_table, Gain_table, R, Epsilon, penalty_scale):
        """Store the inputs, build the partial-order relations and draw a random initial state.

        ``Xlist`` is indexed positionally: 0 -> level-4 input, 1..3 -> level-3
        inputs (3a, 3b, 3c), 4..9 -> level-2 inputs (2a .. 2f).
        """
        self.index_2 = None
        input_attrs = ('X4_input', 'X3a_input', 'X3b_input', 'X3c_input',
                       'X2a_input', 'X2b_input', 'X2c_input', 'X2d_input',
                       'X2e_input', 'X2f_input')
        for position, attr in enumerate(input_attrs):
            setattr(self, attr, Xlist[position])
        self.relation3 = self.get_relation3()
        self.relation2 = self.get_relation2()
        self.State2, self.State3, self.Break2, self.Break3 = self.InitializationState()
        self.Risk_table = Risk_table
        self.Gain_table = Gain_table
        self.R = R
        self.N = self.X4_input.shape[0]
        self.Epsilon = Epsilon
        self.penalty_scale = penalty_scale
        # Working tables start as copies of the module-level reference tables.
        self.tb3 = np.zeros(125, dtype=int)
        self.tb2 = np.zeros(25, dtype=int)
        self.tb3[:] = table3_array[:]
        self.tb2[:] = table2_array[:]
        # Level-4 score of every observation under the reference 3-way table.
        first, second, third = (self.X4_input.iloc[:, k] for k in range(3))
        self.cur_4 = table3_array[first * 25 + second * 5 + third * 1]

    def predict(self,X_input):
        """Propagate 19 columns of 0-4 scores through the current tables to the level-4 score.

        Two derived columns (positions 19 and 20) are appended using the
        module-level 2-way table before the 2 -> 3 -> 4 hierarchy is evaluated
        with ``self.tb2`` / ``self.tb3``.
        """
        frame = pd.DataFrame(X_input, dtype=int)
        frame['1f1g'] = table2_array[frame.iloc[:, 5].values * 5 + frame.iloc[:, 6].values]
        frame['1l1k'] = table2_array[frame.iloc[:, 11].values * 5 + frame.iloc[:, 10].values]

        def col(k):
            return frame.iloc[:, k].values

        tb2, tb3 = self.tb2, self.tb3
        level2_a = tb2[col(0) * 5 + col(1)]
        level2_b = tb3[col(4) * 25 + col(3) * 5 + col(19)]
        level2_c = tb2[col(7) * 5 + col(8)]
        level2_d = tb3[col(9) * 25 + col(20) * 5 + col(6)]
        level2_e = tb3[col(14) * 25 + col(13) * 5 + col(15)]
        level2_f = tb3[col(17) * 25 + col(16) * 5 + col(18)]
        level3_a = tb3[level2_a * 25 + level2_b * 5 + col(2)]  # 2a, 2b, 1c
        level3_b = tb3[level2_d * 25 + level2_c * 5 + col(12)]  # 2d, 2c, 1m
        level3_c = tb2[level2_e * 5 + level2_f]  # 2e, 2f
        return tb3[level3_a * 25 + level3_b * 5 + level3_c]

    def ComputeRiskGain(self):
        """Return the mean (risk, gain) over observations of moving from ``cur_4`` to the new level-4 score."""
        tb2, tb3 = self.tb2, self.tb3

        def via2(frame):
            return tb2[frame['V1'] * 5 + frame['V2']]

        def via3(frame):
            return tb3[frame['V1'] * 25 + frame['V2'] * 5 + frame['V3']]

        level2_a = via2(self.X2a_input)
        level2_b = via3(self.X2b_input)
        level2_c = via2(self.X2c_input)
        level2_d = via3(self.X2d_input)
        level2_e = via3(self.X2e_input)
        level2_f = via3(self.X2f_input)
        level3_a = tb3[level2_a * 25 + level2_b * 5 + self.X3a_input['V3']]  # 2a, 2b, 1c
        level3_b = tb3[level2_d * 25 + level2_c * 5 + self.X3b_input['V3']]  # 2d, 2c, 1m
        level3_c = tb2[level2_e * 5 + level2_f]  # 2e, 2f
        level4 = tb3[level3_a * 25 + level3_b * 5 + level3_c]
        # (current level, new level, observation) triples, one per row.
        triples = np.stack([self.cur_4, level4, np.array(range(self.N))], axis=1)
        cells = tuple(zip(*triples))
        risk_total = self.Risk_table[cells].sum()
        gain_total = self.Gain_table[cells].sum()
        return risk_total / self.N, gain_total / self.N

    def PDP(self, X_input):
        """Partial-dependence table: ``result[i, j]`` is the mean prediction with column ``i`` forced to ``j``.

        ``X_input`` is a 2-D array with (at least) 19 columns of 0-4 scores; it is not modified.
        """
        scratch = np.empty_like(X_input)
        scratch[:] = X_input[:]
        curve = np.empty(shape=(19, 5))
        for feature, level in product(range(19), range(5)):
            scratch[:, feature] = level
            curve[feature, level] = self.predict(scratch).mean()
            scratch[:, feature] = X_input[:, feature]  # restore before the next probe
        return curve

    def InitializationState(self):
        """Draw random topological orders for both tables and pair them with fixed level breakpoints.

        Returns ``(State2, State3, Break2, Break3)``. The states come from
        ``RandomInitialTP`` on the module-level ``table2_cur`` / ``table3_cur``.
        The breakpoints are hard-coded slice positions separating the five
        levels (earlier versions used a fixed hand-written initial order
        instead of a random one).
        """
        order2, order3 = RandomInitialTP(table2_cur, table3_cur)
        cuts2 = np.array([3, 10, 19, 24])
        cuts3 = np.array([14, 51, 93, 120])
        return order2, order3, cuts2, cuts3

    def showplot3(self, policy):
        """Render the 125-cell 3-way grid as a graphviz digraph, colouring node ``n`` by ``policy[n]``.

        Edges point from each cell to its successor along the last, middle and
        first axis (in that order).
        """
        palette = ['red', 'orange', 'green', 'blue', 'black']
        graph = graphviz.Digraph(comment='Partial Order', format='png')
        # NOTE(review): ``policy[n] - 1`` maps level 0 to palette[-1] ('black'); this
        # looks like a leftover from 1-based levels — confirm the intended colours.
        shades = [palette[policy[cell] - 1] for cell in range(125)]
        for cell in range(125):
            graph.node(str(cell), color=shades[cell])
        for a, b, c in product(range(5), range(5), range(4)):
            graph.edge(str(25 * a + 5 * b + c), str(25 * a + 5 * b + c + 1))
        for a, b, c in product(range(5), range(4), range(5)):
            graph.edge(str(25 * a + 5 * b + c), str(25 * a + 5 * b + c + 5))
        for a, b, c in product(range(4), range(5), range(5)):
            graph.edge(str(25 * a + 5 * b + c), str(25 * a + 5 * b + c + 25))
        return graph

    def showplot2(self, policy):
        """Render the 25-cell 2-way grid as a graphviz digraph, colouring node ``n`` by ``policy[n]``.

        Edges point from each cell to its successor along the second, then the first axis.
        """
        palette = ['red', 'orange', 'green', 'blue', 'black']
        graph = graphviz.Digraph(comment='Partial Order', format='png')
        # NOTE(review): ``policy[n] - 1`` maps level 0 to palette[-1]; confirm intended.
        shades = [palette[policy[cell] - 1] for cell in range(25)]
        for cell in range(25):
            graph.node(str(cell), color=shades[cell])
        for row, col in product(range(5), range(4)):
            graph.edge(str(5 * row + col), str(5 * row + col + 1))
        for row, col in product(range(4), range(5)):
            graph.edge(str(5 * row + col), str(5 * row + col + 5))
        return graph

    def get_relation3(self):
        """Return the 125x125 order matrix of the 5x5x5 grid.

        ``rel[a, b]`` is 1 if cell ``b`` is strictly below ``a`` coordinate-wise,
        -1 if strictly above, and 0 if equal or incomparable.
        """
        coords = np.array(list(product(range(5), range(5), range(5))))
        below = (coords[None, :, :] <= coords[:, None, :]).all(axis=2)
        above = (coords[None, :, :] >= coords[:, None, :]).all(axis=2)
        return below.astype(int) - above.astype(int)

    def get_relation2(self):
        """Return the 25x25 order matrix of the 5x5 grid.

        ``rel[a, b]`` is 1 if cell ``b`` is strictly below ``a`` coordinate-wise,
        -1 if strictly above, and 0 if equal or incomparable.
        """
        coords = np.array(list(product(range(5), range(5))))
        below = (coords[None, :, :] <= coords[:, None, :]).all(axis=2)
        above = (coords[None, :, :] >= coords[:, None, :]).all(axis=2)
        return below.astype(int) - above.astype(int)

    def recombination3(self, State):
        """Cut the 3-way order ``State`` at 4 random positions and assign levels 0..4 to the chunks."""
        cuts = sorted(np.random.choice(range(125), 4, replace=False))
        edges = [None, *cuts, None]
        policy = np.empty(shape=(125), dtype=int)
        for level in range(5):
            policy[State[edges[level]:edges[level + 1]]] = level
        return policy

    def recombination2(self, State):
        """Split the order ``State`` into five consecutive classes at random.

        Four distinct cut positions are drawn from ``range(25)`` and sorted;
        ``State[:c0]`` gets label 0, ``State[c0:c1]`` label 1, ...,
        ``State[c3:]`` label 4 (class 0 is empty when ``c0 == 0``).
        Returns a length-25 int array indexed by element.  Assumes ``State``
        is a permutation of ``range(25)``; otherwise some entries are left
        uninitialised (``np.empty``).
        """
        cuts = sorted(np.random.choice(range(25), 4, replace=False))
        edges = [None] + list(cuts) + [None]
        policy = np.empty(shape=(25), dtype=int)
        for label in range(5):
            policy[State[edges[label]:edges[label + 1]]] = label
        return policy

    def recombination(self):
        """Build both policy tables from the stored orders and breakpoints.

        For each (order, breakpoints) pair -- ``(State2, Break2)`` with 25
        elements and ``(State3, Break3)`` with 125 -- ``order[:b0]`` gets
        label 0, ``order[b0:b1]`` label 1, ..., ``order[b3:]`` label 4.
        Only the first four breakpoints are read.

        Returns ``(table2, table3)``: int arrays indexed by element.
        """
        def label_classes(order, breaks, size):
            # One label per slice between consecutive breakpoints.
            edges = [None, breaks[0], breaks[1], breaks[2], breaks[3], None]
            labels = np.empty(shape=(size), dtype=int)
            for cls in range(5):
                labels[order[edges[cls]:edges[cls + 1]]] = cls
            return labels

        table2 = label_classes(self.State2, self.Break2, 25)
        table3 = label_classes(self.State3, self.Break3, 125)
        return table2, table3

    def UniformSample(self, Itenum):
        """Sample orders with a lazy adjacent-transposition random walk.

        At every iteration the current ``self.State2`` / ``self.State3`` are
        recorded.  Then each order independently proposes swapping one pair
        of neighbouring positions.  Proposals are drawn from
        ``range(2 * (n - 1))``, so indices ``>= n - 1`` (half of them) are
        "stay put" moves.  A swap of positions ``(p, p + 1)`` holding
        elements ``(a, b)`` is rejected when ``relation[a, b] < 0``.

        Parameters
        ----------
        Itenum : int
            Number of iterations (rows of each returned trajectory).

        Returns
        -------
        (Trajectory2, Trajectory3) : tuple of int ndarrays
            Shapes ``(Itenum, len(State2))`` and ``(Itenum, len(State3))``.
            Row ``t`` is the state *before* the move of iteration ``t``.

        Notes
        -----
        ``self.State2`` and ``self.State3`` are mutated in place.
        """
        def swap_adjacent(order, relation, pos):
            # Swap order[pos] and order[pos + 1] unless relation forbids it.
            left, right = order[pos], order[pos + 1]
            if relation[left, right] < 0:
                return
            order[pos], order[pos + 1] = right, left

        n2, n3 = len(self.State2), len(self.State3)
        Trajectory2 = np.empty(shape=(Itenum, n2), dtype=int)
        Trajectory3 = np.empty(shape=(Itenum, n3), dtype=int)
        myrandomseed2 = np.random.choice(range(2 * (n2 - 1)), size=Itenum)
        myrandomseed3 = np.random.choice(range(2 * (n3 - 1)), size=Itenum)
        for ite in range(Itenum):
            Trajectory2[ite, :] = self.State2[:]
            Trajectory3[ite, :] = self.State3[:]
            # The two walks are independent: a lazy or rejected State2
            # proposal must not suppress the State3 proposal.
            if myrandomseed2[ite] < n2 - 1:
                swap_adjacent(self.State2, self.relation2, myrandomseed2[ite])
            if myrandomseed3[ite] < n3 - 1:
                swap_adjacent(self.State3, self.relation3, myrandomseed3[ite])
        return Trajectory2, Trajectory3

    def RansomWalkOptim(self, Itenum):
        """Record ``ComputeRiskGain()`` along a lazy random walk over orders.

        Each iteration first stores ``self.ComputeRiskGain()`` in the result.
        Then it proposes one adjacent swap in ``self.State2`` and,
        independently, one in ``self.State3``.  Proposals are drawn from
        ``range(2 * (n - 1))``, and indices ``>= n - 1`` mean "no move".  A
        swap of elements ``(a, b)`` is rejected when ``relation[a, b] < 0``.
        Finally ``self.tb2`` / ``self.tb3`` are rebuilt with
        ``recombination2`` / ``recombination3`` on every iteration.
        NOTE(review): presumably ``ComputeRiskGain`` reads ``tb2``/``tb3``;
        confirm.

        Parameters
        ----------
        Itenum : int
            Number of iterations.

        Returns
        -------
        ndarray of shape (Itenum, 2)
            Row ``ite`` is the (risk, gain) before the move of iteration
            ``ite``.

        Notes
        -----
        Mutates ``State2``, ``State3``, ``tb2`` and ``tb3``.
        """
        def swap_adjacent(order, relation, pos):
            # Swap order[pos] and order[pos + 1] unless relation forbids it.
            left, right = order[pos], order[pos + 1]
            if relation[left, right] < 0:
                return
            order[pos], order[pos + 1] = right, left

        n2, n3 = len(self.State2), len(self.State3)
        result = np.empty(shape=(Itenum, 2))
        myrandomseed2 = np.random.choice(range(2 * (n2 - 1)), size=Itenum)
        myrandomseed3 = np.random.choice(range(2 * (n3 - 1)), size=Itenum)
        for ite in range(Itenum):
            Risk, Gain = self.ComputeRiskGain()
            result[ite, :] = Risk, Gain
            # Independent proposals: a lazy or rejected State2 move must not
            # skip the State3 move or the table refresh below.
            if myrandomseed2[ite] < n2 - 1:
                swap_adjacent(self.State2, self.relation2, myrandomseed2[ite])
            if myrandomseed3[ite] < n3 - 1:
                swap_adjacent(self.State3, self.relation3, myrandomseed3[ite])
            self.tb2 = self.recombination2(self.State2)
            self.tb3 = self.recombination3(self.State3)
        return result

    def ShortBurstOptim(self, Itenum, Burst_size):
        """Short-burst hill climbing over orders with random class cuts.

        Each of the ``Itenum`` outer iterations runs a burst of
        ``Burst_size`` steps starting from the current state.  A step does
        the following:

        * It proposes one adjacent swap in ``self.State2`` (position drawn
          from ``range(32)``, applied only when < 24).
        * It proposes one in ``self.State3`` (from ``range(160)``, applied
          only when < 124).  A swap of elements ``(a, b)`` is skipped when
          ``relation[a, b] < 0``.
        * It rebuilds ``tb2``/``tb3`` via ``recombination2``/``recombination3``
          and scores them with ``ComputeRiskGain``.

        A scored state replaces the burst's best when ``risk < self.Epsilon``
        and its gain exceeds the best gain so far.  At the end of the burst
        the object is reset to that best state, or to the burst's starting
        state if nothing qualified.

        Returns
        -------
        ndarray of shape (Itenum, 2)
            Row ``ite`` is the (risk, gain) of the state kept after burst
            ``ite``.

        Notes
        -----
        Mutates ``State2``, ``State3``, ``tb2`` and ``tb3``, and prints the
        number of steps in which at least one swap was applied.
        """
        def try_swap(order, relation, pos):
            # Swap order[pos] and order[pos + 1] unless relation forbids it.
            left, right = order[pos], order[pos + 1]
            if relation[left, right] < 0:
                return False
            order[pos], order[pos + 1] = right, left
            return True

        total_steps = Itenum * Burst_size
        result = np.empty(shape=(Itenum, 2))
        picks2 = np.random.choice(range(32), size=total_steps)
        picks3 = np.random.choice(range(160), size=total_steps)
        changed_num = 0
        step = 0
        for ite in range(Itenum):
            start = self.ComputeRiskGain()
            best_risk, best_gain = start[0], start[1]
            best_tb2 = np.array(self.tb2, dtype=int)
            best_tb3 = np.array(self.tb3, dtype=int)
            best_state2 = np.array(self.State2, dtype=int)
            best_state3 = np.array(self.State3, dtype=int)
            for _ in range(Burst_size):
                moved = False
                if picks2[step] < 24 and try_swap(self.State2, self.relation2, picks2[step]):
                    moved = True
                if picks3[step] < 124 and try_swap(self.State3, self.relation3, picks3[step]):
                    moved = True
                step += 1
                if moved:
                    changed_num += 1
                self.tb2 = self.recombination2(self.State2)
                self.tb3 = self.recombination3(self.State3)
                risk, gain = self.ComputeRiskGain()
                if risk < self.Epsilon and gain > best_gain:
                    best_tb2 = np.array(self.tb2, dtype=int)
                    best_tb3 = np.array(self.tb3, dtype=int)
                    best_state2 = np.array(self.State2, dtype=int)
                    best_state3 = np.array(self.State3, dtype=int)
                    best_risk, best_gain = risk, gain
            self.State2 = best_state2
            self.State3 = best_state3
            self.tb2 = best_tb2
            self.tb3 = best_tb3
            result[ite, :] = best_risk, best_gain
        print(changed_num)
        return result

    def ShortBurstMCMCOptim(self, Itenum, Burst_size, regularization=0):
        """Short-burst hill climbing over orders *and* class breakpoints.

        The five policy classes are cut from the orders at the stored
        breakpoints ``self.Break2`` / ``self.Break3`` via
        ``self.recombination()``.  Each burst step proposes:

        * one adjacent swap in ``State2`` (position from ``range(32)``,
          applied when ``< len(State2) - 1``);
        * one adjacent swap in ``State3`` (from ``range(160)``, applied when
          ``< len(State3) - 1``).  A swap of ``(a, b)`` is skipped when
          ``relation[a, b] < 0``;
        * a shift by -1, 0 or +1 of one randomly chosen breakpoint in each
          of ``Break2`` and ``Break3``.

        A breakpoint shift is rejected when it collides with another
        breakpoint (this also rejects a 0 shift), or when it leaves
        ``[0, n)``, where ``n`` is the length of the matching order.  This
        is the range ``recombination2``/``recombination3`` draw cuts from.
        Without the lower bound a breakpoint could go negative, and
        ``State[:b]`` would then silently slice from the end.

        The state is scored with ``ComputeRiskGain``.  It becomes the
        burst's best when ``risk < self.Epsilon`` and
        ``gain - regularization * risk`` beats the best so far.  At the end
        of each burst the object is reset to the best state.

        Returns
        -------
        ndarray of shape (Itenum, 2)
            Row ``ite`` is the (risk, gain) kept after burst ``ite``.

        Notes
        -----
        Mutates ``State2``, ``State3``, ``Break2``, ``Break3``, ``tb2`` and
        ``tb3``, and prints the number of steps in which anything changed.
        """
        def try_swap(order, relation, pos):
            # Swap order[pos] and order[pos + 1] unless relation forbids it.
            left, right = order[pos], order[pos + 1]
            if relation[left, right] < 0:
                return False
            order[pos], order[pos + 1] = right, left
            return True

        def try_shift(breaks, idx, delta, size):
            # Steps are +/-1, so a collision-free move keeps breaks sorted;
            # the range check keeps every breakpoint a valid slice bound.
            candidate = breaks[idx] + delta
            if candidate in breaks or not 0 <= candidate < size:
                return False
            breaks[idx] = candidate
            return True

        n2, n3 = len(self.State2), len(self.State3)
        result = np.empty(shape=(Itenum, 2))
        myrandomseed2 = np.random.choice(range(32), size=Itenum*Burst_size)
        myrandomseed3 = np.random.choice(range(160), size=Itenum*Burst_size)
        myrandomseed_brkpt = np.random.choice(range(4), size=2*Itenum*Burst_size).reshape(2, -1)
        myrandomseed_direction = np.random.choice([-1, 0, 1], size=2*Itenum*Burst_size).reshape(2, -1)
        changed_num = 0
        for ite in range(Itenum):
            start = self.ComputeRiskGain()
            best_risk, best_gain = start[0], start[1]
            best_state2 = np.array(self.State2, dtype=int)
            best_state3 = np.array(self.State3, dtype=int)
            best_break2 = np.array(self.Break2, dtype=int)
            best_break3 = np.array(self.Break3, dtype=int)
            for k in range(Burst_size):
                t = ite * Burst_size + k
                Is_changed = False
                if myrandomseed2[t] < n2 - 1 and try_swap(self.State2, self.relation2, myrandomseed2[t]):
                    Is_changed = True
                if myrandomseed3[t] < n3 - 1 and try_swap(self.State3, self.relation3, myrandomseed3[t]):
                    Is_changed = True
                if try_shift(self.Break2, myrandomseed_brkpt[0, t], myrandomseed_direction[0, t], n2):
                    Is_changed = True
                if try_shift(self.Break3, myrandomseed_brkpt[1, t], myrandomseed_direction[1, t], n3):
                    Is_changed = True
                if Is_changed:
                    changed_num += 1
                self.tb2, self.tb3 = self.recombination()
                risk, gain = self.ComputeRiskGain()
                if risk < self.Epsilon and gain - regularization * risk > best_gain - regularization * best_risk:
                    best_state2 = np.array(self.State2, dtype=int)
                    best_state3 = np.array(self.State3, dtype=int)
                    best_break2 = np.array(self.Break2, dtype=int)
                    best_break3 = np.array(self.Break3, dtype=int)
                    best_risk, best_gain = risk, gain
            self.State2 = best_state2
            self.State3 = best_state3
            self.Break2 = best_break2
            self.Break3 = best_break3
            self.tb2, self.tb3 = self.recombination()
            result[ite, :] = best_risk, best_gain
        print(changed_num)
        return result

    def ShortBurstMCMCOptim_sa(self, Itenum, Burst_size, regularization=0, anneal_rate=1.005):
        """Annealed short-burst search over orders and class breakpoints.

        Each burst step proposes:

        * one adjacent swap in ``State2`` (position from ``range(32)``,
          applied when ``< len(State2) - 1``);
        * one adjacent swap in ``State3`` (from ``range(160)``, applied when
          ``< len(State3) - 1``).  A swap of ``(a, b)`` is skipped when
          ``relation[a, b] < 0``;
        * a shift by -3..+3 of one randomly chosen breakpoint in each of
          ``Break2`` and ``Break3``.

        Because shifts can exceed 1, a shift is accepted only if the moved
        breakpoint stays strictly between its neighbours.  At the ends it
        must stay inside ``[0, n)``, where ``n`` is the matching order's
        length.  This keeps the cuts sorted and non-negative, so
        ``recombination()`` slices consecutive classes.  (Assumes the
        breakpoints are sorted on entry.)

        Unlike the non-annealed variant there is no ``Epsilon`` constraint.
        A state becomes the burst's best when ``gain - regularization * risk``
        beats the best so far.  After each burst the object is reset to the
        best state.  ``regularization`` is then multiplied by
        ``anneal_rate`` (default 1.005) and printed.

        Returns
        -------
        ndarray of shape (Itenum, 2)
            Row ``ite`` is the (risk, gain) kept after burst ``ite``.

        Notes
        -----
        Mutates ``State2``, ``State3``, ``Break2``, ``Break3``, ``tb2`` and
        ``tb3``, and finally prints the number of steps with a swap.
        """
        def try_swap(order, relation, pos):
            # Swap order[pos] and order[pos + 1] unless relation forbids it.
            left, right = order[pos], order[pos + 1]
            if relation[left, right] < 0:
                return False
            order[pos], order[pos + 1] = right, left
            return True

        def try_shift(breaks, idx, delta, size):
            # A zero shift is a no-op; otherwise the moved breakpoint must
            # stay strictly between its neighbours (or inside [0, size)).
            if delta == 0:
                return False
            lower = breaks[idx - 1] if idx > 0 else -1
            upper = breaks[idx + 1] if idx + 1 < len(breaks) else size
            candidate = breaks[idx] + delta
            if not lower < candidate < upper:
                return False
            breaks[idx] = candidate
            return True

        n2, n3 = len(self.State2), len(self.State3)
        result = np.empty(shape=(Itenum, 2))
        myrandomseed2 = np.random.choice(range(32), size=Itenum*Burst_size)
        myrandomseed3 = np.random.choice(range(160), size=Itenum*Burst_size)
        myrandomseed_brkpt = np.random.choice(range(4), size=2*Itenum*Burst_size).reshape(2, -1)
        myrandomseed_direction = np.random.choice([-3, -2, -1, 0, 1, 2, 3], size=2*Itenum*Burst_size).reshape(2, -1)
        changed_num = 0
        for ite in range(Itenum):
            start = self.ComputeRiskGain()
            best_risk, best_gain = start[0], start[1]
            best_state2 = np.array(self.State2, dtype=int)
            best_state3 = np.array(self.State3, dtype=int)
            best_break2 = np.array(self.Break2, dtype=int)
            best_break3 = np.array(self.Break3, dtype=int)
            for k in range(Burst_size):
                t = ite * Burst_size + k
                Is_changed = False
                if myrandomseed2[t] < n2 - 1 and try_swap(self.State2, self.relation2, myrandomseed2[t]):
                    Is_changed = True
                if myrandomseed3[t] < n3 - 1 and try_swap(self.State3, self.relation3, myrandomseed3[t]):
                    Is_changed = True
                # Breakpoint moves are not counted in changed_num here.
                try_shift(self.Break2, myrandomseed_brkpt[0, t], myrandomseed_direction[0, t], n2)
                try_shift(self.Break3, myrandomseed_brkpt[1, t], myrandomseed_direction[1, t], n3)
                if Is_changed:
                    changed_num += 1
                self.tb2, self.tb3 = self.recombination()
                risk, gain = self.ComputeRiskGain()
                if gain - regularization * risk > best_gain - regularization * best_risk:
                    best_state2 = np.array(self.State2, dtype=int)
                    best_state3 = np.array(self.State3, dtype=int)
                    best_break2 = np.array(self.Break2, dtype=int)
                    best_break3 = np.array(self.Break3, dtype=int)
                    best_risk, best_gain = risk, gain
            self.State2 = best_state2
            self.State3 = best_state3
            self.Break2 = best_break2
            self.Break3 = best_break3
            self.tb2, self.tb3 = self.recombination()
            # Anneal: penalise risk more heavily as the search proceeds.
            regularization *= anneal_rate
            print(regularization)
            result[ite, :] = best_risk, best_gain
        print(changed_num)
        return result

